from . import db
import json
import copy

from common.utils.format_util import py_to_java
from common.log import log_handler

log = log_handler.LogHandler().get_log()


class Task(db.Model):
    """ORM model for a pipeline task row.

    Stores a task's identity/linkage columns plus a ``data_json`` payload
    (a JSON/stringified dict describing the task's input/output tables).

    NOTE(review): all columns except ``id`` and ``name`` are declared with
    ``db.Column()`` and no type; presumably the concrete types come from an
    existing database schema or migrations — confirm before altering.
    NOTE(review): the ``type`` column/parameter shadows the builtin ``type``;
    renaming would change the schema and callers, so it is left as-is.
    """
    id = db.Column(db.Integer, primary_key=True)   # surrogate primary key
    name = db.Column(db.String(100))               # human-readable task name
    project_id = db.Column()                       # owning project (typeless column)
    pipeline_id = db.Column()                      # owning pipeline (typeless column)
    parent_id = db.Column()                        # id of the parent Task, used by update_task
    user_id = db.Column()                          # owning user (typeless column)
    type = db.Column()                             # task kind discriminator
    data_json = db.Column()                        # serialized task payload

    def __init__(self, name, project_id, type, data_json):
        # parent_id / pipeline_id / user_id are not set here; presumably
        # assigned by callers after construction — confirm.
        self.name = name
        self.project_id = project_id
        self.type = type
        self.data_json = data_json

    def update_data_json(self, data_json):
        # Replace the serialized payload; caller is responsible for committing.
        self.data_json = data_json


def update_task(taskId, df, target, update, results=None):
    """Merge prediction columns from *df* into a task's ``data_json``.

    Rebuilds the task's ``output`` section from a deep copy of its parent
    task's output: replaces ``tableCols`` with the dataframe's columns,
    sanitizes ``semantic`` keys into SQL-safe names, sets ``tableName`` to
    *target*, copies *results* entries onto the payload, and — when
    *update* is true — registers each brand-new column with a placeholder
    semantic and a REAL column type.

    Args:
        taskId: primary key of the Task to update.
        df: dataframe whose columns become the task's output columns.
        target: name written into the output's ``tableName``.
        update: when truthy, insert semantics/column types for new columns.
        results: optional extra key/value pairs merged into the payload.

    Returns:
        0 if the parent task has no output section, 1 on success.
    """
    # Avoid the mutable-default-argument trap of the original `results={}`.
    if results is None:
        results = {}

    try:
        task = Task.query.filter_by(id=taskId).first()
        parent_task = Task.query.filter_by(id=task.parent_id).first()
    except Exception:
        # Stale/broken session: roll back and retry the lookups once.
        db.session.rollback()
        log.info("session rollback")
        task = Task.query.filter_by(id=taskId).first()
        parent_task = Task.query.filter_by(id=task.parent_id).first()

    data_json = json.loads(task.data_json)
    parent_data_json = json.loads(parent_task.data_json)
    log.info("PREVIOUS ORIGINAL DATA JSON__________")

    columns_df_pred = df.columns.to_list()
    # Deep-copy so we never mutate the parent's parsed payload in place.
    data_json_output = copy.deepcopy(parent_data_json["output"])
    if not data_json_output:
        return 0

    # Columns present in the dataframe but not in the parent's table.
    old_columns = copy.deepcopy(data_json_output[0]["tableCols"])
    new_columns = [col for col in columns_df_pred if col not in old_columns]
    data_json_output[0]["tableCols"] = columns_df_pred

    # Sanitize semantic keys into SQL-safe identifiers (spaces, slashes and
    # dashes become underscores). Later colliding keys overwrite earlier
    # ones, matching the original loop's behavior.
    semantic_old = copy.deepcopy(data_json_output[0]["semantic"])
    data_json_output[0]["semantic"] = {
        key.replace(" ", "_").replace("/", "_").replace("-", "_"): value
        for key, value in semantic_old.items()
    }
    data_json_output[0]["tableName"] = target

    # Caller-supplied result fields are merged onto the payload root.
    for key in results:
        data_json[key] = results[key]

    if update:
        # Each new column gets a placeholder semantic and a REAL column type
        # inserted at the same position it holds in tableCols.
        for new_column in new_columns:
            data_json_output[0]["semantic"][new_column] = "null"
            index = data_json_output[0]["tableCols"].index(new_column)
            data_json_output[0]["columnTypes"].insert(index, "REAL")

    data_json["output"] = data_json_output
    # NOTE(review): str() yields the Python repr, not JSON text; py_to_java
    # presumably converts that repr into the Java-side format — confirm.
    data_json = py_to_java(str(data_json))
    task.update_data_json(data_json)
    db.session.commit()
    return 1


def get_col_type(taskId):
    """Return a mapping of input column name -> column type for a task.

    Reads the task's first ``input`` table and pairs ``tableCols`` with
    ``columnTypes``, skipping the synthetic ``_record_id_`` column.
    """
    task = Task.query.filter_by(id=taskId).first()
    payload = json.loads(task.data_json)
    table = payload["input"][0]
    return {
        name: col_type
        for name, col_type in zip(table["tableCols"], table["columnTypes"])
        if name != "_record_id_"
    }


def filter_numeric_col(taskId):
    """Split a task's input columns into numeric and non-numeric lists.

    Returns a ``(numeric_cols, non_numeric_cols)`` tuple; both lists keep
    the original column order.
    """
    numeric_types = {"BIGINT", "INTEGER", "DECIMAL", "FLOAT",
                     "SMALLINT", "NUMERIC", "DOUBLE", "REAL"}
    col_type = get_col_type(taskId)
    numeric_cols = [col for col, typ in col_type.items() if typ in numeric_types]
    non_numeric_cols = [col for col, typ in col_type.items() if typ not in numeric_types]
    return numeric_cols, non_numeric_cols
